from tqdm import tqdm
import torch
import sys
import math

from torch import float16
import pandas as pd
import time
# import git

from trainer_utils import  get_dynamic_ranks, accuracy_top_k,count_params


def _count_correct(outputs, labels):
    """Return how many rows of ``outputs`` have their argmax equal to ``labels``.

    The result is a 0-dim integer tensor. Counting in an integer dtype
    (instead of casting to float16) avoids the copy-construct warning of
    ``torch.tensor(tensor)``. It also stays exact for large batches: float16
    cannot represent every integer above 2048.
    """
    return (torch.argmax(outputs.detach(), dim=1) == labels).sum()


def _evaluate(NN, loader, criterion, metric, device, with_top5=False):
    """Run ``NN`` over every batch of ``loader`` and accumulate loss/metric.

    The caller is responsible for ``NN.eval()`` and ``torch.no_grad()``.
    Each per-batch value is divided by ``len(loader) * loader.batch_size``,
    using this loader's own batch size.

    Returns:
        (loss, metric, top5) as floats; ``top5`` is 0.0 unless ``with_top5``.
    """
    denom = len(loader) * loader.batch_size
    loss_sum = 0.0
    metric_sum = 0.0
    top5_sum = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = NN(inputs).detach()
        loss_sum += float(criterion(outputs, labels).item()) / denom
        metric_sum += float(metric(outputs, labels)) / denom
        if with_top5:
            top5_sum += float(accuracy_top_k(outputs, labels, topk=(5,))[5]) / denom
    return loss_sum, metric_sum, top5_sum


def _train_epoch(NN, optimizer, loader, criterion, metric, device):
    """Run one optimisation pass of ``NN`` over ``loader``.

    Loss and metric are normalised by ``len(loader) * loader.batch_size``.
    The returned batch time is the mean wall-clock time per batch for the
    transfer, forward, backward and bookkeeping (optimizer step excluded).

    Returns:
        (loss, metric, average_batch_time) as floats.

    Raises:
        SystemExit: with status 1 if the accumulated loss becomes NaN.
    """
    n_batches = len(loader)
    denom = n_batches * loader.batch_size
    loss_sum = 0.0
    metric_sum = 0.0
    average_batch_time = 0.0

    NN.train()
    for inputs, labels in loader:
        optimizer.zero_grad()
        start = time.perf_counter()
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = NN(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        loss_sum += float(loss.detach().item()) / denom
        metric_sum += float(metric(outputs, labels)) / denom

        average_batch_time += (time.perf_counter() - start) / n_batches

        if math.isnan(loss_sum):
            print("Training diverged! Loss is nan")
            sys.exit(1)
        optimizer.step()
    return loss_sum, metric_sum, average_batch_time


def _save_with_backup(save_fn, path, save_name, extension):
    """Call ``save_fn(path + save_name + extension)``, then try a backup copy.

    The backup target is ``path + '/drive/MyDrive/nips2023_results/' +
    save_name + extension``. Judging by the path, this is a mounted Google
    Drive folder. A failure of the primary save propagates. A failure of the
    backup is reported on stdout and otherwise ignored.
    """
    save_fn(path + save_name + extension)
    backup = path + '/drive/MyDrive/nips2023_results/' + save_name + extension
    try:
        save_fn(backup)
    except (OSError, RuntimeError):
        # OSError: writing a CSV/file into a missing directory.
        # RuntimeError: torch.save may raise this when the parent dir is missing.
        print("Tried: " + backup + ' for additional backup, but:')
        print('drive not found')


def train(NN, optimizer, train_loader, validation_loader, test_loader, criterion, metric, epochs,
          metric_name='accuracy', device='cpu', count_bias=False, path=None, epoch_status_bar=False,
          fine_tune=False, scheduler=None, save_weights=True, save_progress=False, save_name=''):
    """Train ``NN`` for ``epochs`` epochs and return a per-epoch results table.

    Each epoch first evaluates the current model on ``validation_loader`` (and
    ``test_loader`` if given), then runs one training pass over
    ``train_loader``. Finally it logs ranks/losses and appends one row to the
    results table.

    INPUTS:
    NN : network exposing ``lr_model`` (passed to ``get_dynamic_ranks``) and
         ``args.tau`` (recorded in the 'tau' column)
    optimizer : torch optimizer; ``param_groups[0]['lr']`` is recorded
    train/validation/test_loader : data loaders yielding ``(inputs, labels)``;
        ``test_loader`` may be None, in which case the test column holds -100
    criterion : loss function ``criterion(outputs, labels)``
    metric : ignored -- an argmax-accuracy count is always used (kept for
        interface compatibility)
    epochs : number of epochs to train
    metric_name : label used in column names and log lines
    device : device the batches are moved to
    count_bias : unused (``count_params`` is called without it)
    path : prefix for output files; when None nothing is written to disk
    epoch_status_bar, fine_tune : unused, kept for interface compatibility
    scheduler : if given, ``scheduler.step(train_loss)`` is called each epoch
    save_weights : save ``NN`` to ``path + save_name + '.pt'`` whenever the
        validation loss improves on the best seen so far (epoch 0 is the
        baseline and is not saved)
    save_progress : write the results table to ``path + save_name + '.csv'``
        after every epoch
    save_name : file name stem for the outputs above

    OUTPUTS:
    running_data : pandas DataFrame with one row per epoch

    Losses and metric values are accumulated per batch and divided by
    ``len(loader) * loader.batch_size``.
    NOTE(review): if ``criterion`` averages over the batch (reduction='mean'),
    the reported losses are additionally divided by the batch size -- confirm
    the intended reduction.
    """
    running_data = pd.DataFrame(data=None, columns=['epoch', 'tau', 'learning_rate', 'train_loss',
                                                    'train_' + metric_name + '(%)', 'validation_loss',
                                                    'validation_' + metric_name + '(%)',
                                                    'top5_validation_' + metric_name + '(%)',
                                                    'test_' + metric_name + '(%)',
                                                    'ranks', '# effective parameters conv',
                                                    'timing batch forward'])

    params_test = count_params(NN)

    # The caller-supplied ``metric`` has always been overridden here; keep
    # that behavior so existing result tables stay comparable.
    metric = _count_correct

    best_val_loss = None
    for epoch in tqdm(range(epochs)):

        NN.eval()
        with torch.no_grad():
            loss_hist_val, acc_hist_val, acc_top5_hist_val = _evaluate(
                NN, validation_loader, criterion, metric, device, with_top5=True)
            if test_loader is not None:
                _, acc_hist_test, _ = _evaluate(NN, test_loader, criterion, metric, device)
            else:
                acc_hist_test = -1

        print(f'epoch {epoch}, acc_val {acc_hist_val}---------------------------------------------')

        # Checkpoint *before* training so the saved weights are exactly the
        # ones whose validation loss was just measured. The first epoch only
        # establishes the baseline, as before.
        if best_val_loss is None:
            best_val_loss = loss_hist_val
        elif loss_hist_val < best_val_loss and save_weights and path is not None:
            print('save')
            best_val_loss = loss_hist_val
            _save_with_backup(lambda target: torch.save(NN, target), path, save_name, '.pt')

        # Uses train_loader's own batch size for normalisation (it used to be
        # clobbered by the validation/test loaders' batch sizes).
        loss_hist, acc_hist, average_batch_time = _train_epoch(
            NN, optimizer, train_loader, criterion, metric, device)

        ranks = get_dynamic_ranks(NN.lr_model)
        print('\n')
        for layer_idx, rank in enumerate(ranks):
            print(f'rank layer {layer_idx} {rank}')
        print('\n')
        print(
            f'epoch[{epoch}/{epochs}]: loss: {loss_hist:9.4f} | {metric_name}: {acc_hist:9.4f} | val loss: {loss_hist_val:9.4f} | val {metric_name}:{acc_hist_val:9.4f}')
        print('=' * 100)

        compression_hyperparam = NN.args.tau
        lr = round(float(optimizer.param_groups[0]['lr']), 4)
        running_data.loc[epoch] = [epoch, compression_hyperparam, lr, round(loss_hist, 3),
                                   round(acc_hist * 100, 4), round(loss_hist_val, 3),
                                   round(acc_hist_val * 100, 4), round(acc_top5_hist_val * 100, 4),
                                   round(acc_hist_test * 100, 4), ranks, params_test,
                                   average_batch_time]

        print(path)
        if path is not None and save_progress:
            _save_with_backup(running_data.to_csv, path, save_name, '.csv')

        if scheduler is not None:
            # NOTE(review): steps on the *training* loss; for ReduceLROnPlateau
            # the validation loss is the more common choice -- confirm intent.
            scheduler.step(loss_hist)

    return running_data